import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer,MLA
from layers.Embed import PatchEmbedding
from collections import Counter
from layers.SharedWavMoE import WavMoE
from layers.RevIN import RevIN
import torch.fft
from layers.Embed import DataEmbedding

class FlattenHead(nn.Module):
    """Output head: a Linear map over the last dimension followed by dropout.

    Args:
        n_vars: number of variables; stored on the instance but not used by
            this class's computation.
        nf: size of the input's last dimension.
        target_window: size of the output's last dimension.
        head_dropout: dropout probability applied after the Linear layer.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map ``x[..., nf]`` to ``[..., target_window]``.

        No flattening is performed: the Linear layer acts only on the last
        dimension, so all leading dimensions are preserved.
        """
        x = self.linear(x)
        x = self.dropout(x)
        return x



class Model(nn.Module):
    """Two-branch Transformer-encoder model followed by a WavMoE stage.

    The input is RevIN-normalised and processed by two encoders:

    * ``encoder2`` runs on ``DataEmbedding(x, x_mark_enc)``. Its output is
      projected back to ``enc_in`` channels by ``data_projection``.
    * ``encoder1`` runs on ``PatchEmbedding`` of the channel-first series. Its
      output is reshaped per variable and projected by ``projection``.

    The two branch outputs are summed; the first branch is weighted by
    ``ANOMALY_NAV_SCALE`` for anomaly detection. The sum is passed through
    ``WavMoE`` and ``FlattenHead`` (``target_window=pred_len``), and finally
    RevIN-denormalised.
    """

    # Weight applied to the DataEmbedding/encoder2 branch in anomaly_detection
    # (the forecast path uses 1.0). The rationale for 0.01 is not visible here.
    ANOMALY_NAV_SCALE = 0.01

    def __init__(self, configs):
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # Patching hyper-parameters are hard-coded rather than read from configs.
        self.patch_len = 16
        self.stride = 8
        # Embedding configs
        self.output_attention = configs.output_attention
        self.padding = self.stride
        # Mixture-of-experts settings (WavMoE below is built from the full configs).
        self.hidden_size = configs.hidden_size
        self.intermediate_size = configs.intermediate_size
        self.top_k = configs.top_k
        self.shared_experts = configs.shared_experts
        self.wavelet = configs.wavelet
        # NOTE(review): `level` is read from `shared_experts`, which looks like a
        # copy-paste slip (perhaps `configs.level` was intended) — confirm.
        # `self.level` is not used anywhere in this file.
        self.level = configs.shared_experts
        self.proj_wight = configs.proj_wight  # (sic) attribute name mirrors the config key

        # Embeddings (construction order is kept stable so parameter
        # initialisation consumes the RNG in the same sequence).
        self.patch_embedding = PatchEmbedding(
            configs.d_model, self.patch_len, self.stride, self.padding, configs.dropout)
        self.data_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.revin_layer = RevIN(configs.enc_in)
        self.encoder1 = self._build_encoder(configs)
        self.encoder2 = self._build_encoder(configs)

        # d_model * patch count; assumes PatchEmbedding pads by `padding` (= stride),
        # which yields (seq_len - patch_len) / stride + 2 patches — verify.
        self.head_nf = configs.d_model * \
            int((configs.seq_len - self.patch_len) / self.stride + 2)
        self.projection = nn.Linear(
            self.head_nf, int(configs.seq_len * self.proj_wight), bias=True)

        self.data_projection = nn.Linear(configs.d_model, configs.enc_in, bias=True)
        self.wavmoe = WavMoE(configs)
        # Forecasting and anomaly detection use an identical head; other tasks
        # get no head (forward returns None for them).
        if self.task_name in ('long_term_forecast', 'short_term_forecast',
                              'anomaly_detection'):
            self.head = FlattenHead(configs.enc_in,
                                    nf=int(configs.pred_len * self.proj_wight),
                                    target_window=self.pred_len,
                                    head_dropout=configs.dropout)
        self.gelu = nn.GELU()  # NOTE(review): not used anywhere in this file.

    @staticmethod
    def _build_encoder(configs):
        """Build an Encoder of `e_layers` FullAttention layers with a final LayerNorm."""
        return Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor,
                                      attention_dropout=configs.dropout,
                                      output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for _ in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
        )

    def _encode_predict(self, x_enc, x_mark_enc, nav_scale=1.0):
        """Run the shared two-branch pipeline and return the denormalised output.

        Args:
            x_enc: input series; assumed ``[batch, seq_len, enc_in]`` — TODO confirm.
            x_mark_enc: time features forwarded to ``DataEmbedding``.
            nav_scale: weight applied to the DataEmbedding/encoder2 branch
                before it is added to the patch branch.

        Returns:
            ``revin_layer(head_out.permute(0, 2, 1), 'denorm')``.
        """
        # RevIN normalisation, then channel-first layout for the patch branch.
        x_revin = self.revin_layer(x_enc, 'norm').permute(0, 2, 1)

        # Branch 1: DataEmbedding + encoder2, projected back to enc_in channels.
        x_inver = self.data_embedding(x_revin.permute(0, 2, 1), x_mark_enc)
        nav_out, _ = self.encoder2(x_inver, attn_mask=None)
        nav_out = self.data_projection(nav_out)

        # Branch 2: patch embedding + encoder1.
        # Patch tokens: [bs * nvars x patch_num x d_model] (per the original note).
        batch_size, n_channels, _ = x_revin.shape
        x_pe, _ = self.patch_embedding(x_revin)
        enc_out, _ = self.encoder1(x_pe)
        act_val = self.projection(enc_out.reshape(batch_size, n_channels, -1))

        # NOTE(review): act_val's last dim is int(seq_len * proj_wight) while
        # nav's is the time length of x_enc; the sum only works when these
        # agree (or broadcast) — confirm the intended proj_wight.
        nav = nav_out.permute(0, 2, 1)
        if nav_scale != 1.0:
            nav = nav_scale * nav

        # Mixture-of-experts stage; router logits are not used here.
        moe_out, _ = self.wavmoe(act_val + nav)
        head_out = self.head(moe_out)
        return self.revin_layer(head_out.permute(0, 2, 1), 'denorm')

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast path: both branches summed with equal weight.

        ``x_dec`` and ``x_mark_dec`` are accepted for interface compatibility
        but are not used.
        """
        return self._encode_predict(x_enc, x_mark_enc)

    def anomaly_detection(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Anomaly-detection path: the encoder2 branch is scaled by ANOMALY_NAV_SCALE.

        ``x_dec`` and ``x_mark_dec`` are accepted for interface compatibility
        but are not used.
        """
        return self._encode_predict(x_enc, x_mark_enc,
                                    nav_scale=self.ANOMALY_NAV_SCALE)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``task_name``; returns None for unsupported tasks."""
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc, x_mark_enc, x_dec, x_mark_dec)  # [B, L, D]
        return None

